{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.7.9 64-bit ('gym': conda)",
   "metadata": {
    "interpreter": {
     "hash": "129f2f7391b777d52a08fa9ff141bedb973b18bd5660445398f73676c694c50c"
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import os\n",
    "import time\n",
    "from itertools import count\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "from entities.network import DQN\n",
    "from entities.environment import EnvironmentManager\n",
    "from entities.strategy import EpsilonGreedyStrategy\n",
    "from entities.memory import ReplayMemory\n",
    "from entities.experience import Experience\n",
    "from entities.q_values import QValues\n",
    "from entities.agent import Agent, MODES\n",
    "from entities.utils import plot, create_torch_device, extract_tensors\n",
    "\n",
    "from entities.constants import POLICY_NET_FILE"
   ]
  },
  {
   "source": [
    "## Example of non-processed screen vs. processed screen"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the raw rendered frame next to the processed screen returned by the\n",
    "# environment manager.\n",
    "device = create_torch_device()\n",
    "env_manager = EnvironmentManager(device)\n",
    "\n",
    "# Render first screen\n",
    "non_processed = env_manager.render('rgb_array')\n",
    "processed = env_manager.get_processed_screen()\n",
    "\n",
    "# Setup the graphs\n",
    "fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n",
    "axs[0].imshow(non_processed)\n",
    "axs[0].set_title('Non-processed screen')\n",
    "\n",
    "# matplotlib can only read host memory. The processed screen presumably lives\n",
    "# on `device`, so copy it to the CPU first (a no-op when it is already there).\n",
    "# NOTE(review): squeeze/permute presumably maps (1, C, H, W) -> (H, W, C) - confirm.\n",
    "axs[1].imshow(processed.squeeze(0).permute(1, 2, 0).cpu(), interpolation='none')\n",
    "axs[1].set_title('Processed screen')\n",
    "plt.show()"
   ]
  },
  {
   "source": [
    "## Example of starting state"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the state reported by the environment manager before any action\n",
    "# has been taken in this notebook.\n",
    "screen = env_manager.get_state()\n",
    "\n",
    "plt.figure()\n",
    "# Copy to the CPU before plotting: matplotlib cannot read a tensor that lives\n",
    "# on an accelerator device (the call is a no-op for CPU tensors).\n",
    "plt.imshow(screen.squeeze(0).permute(1, 2, 0).cpu(), interpolation='none')\n",
    "plt.title('Starting screen')\n",
    "plt.show()"
   ]
  },
  {
   "source": [
    "## Example of a non-starting state"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Advance the environment by five steps (always action index 1) so that the\n",
    "# state shown below is no longer the starting one.\n",
    "for i in range(5):\n",
    "    env_manager.take_action(torch.tensor([1]))\n",
    "screen = env_manager.get_state()\n",
    "\n",
    "plt.figure()\n",
    "# Copy to the CPU before plotting: matplotlib cannot read a tensor that lives\n",
    "# on an accelerator device (the call is a no-op for CPU tensors).\n",
    "plt.imshow(screen.squeeze(0).permute(1, 2, 0).cpu(), interpolation='none')\n",
    "# Title fixed: this figure shows a state reached after several actions.\n",
    "plt.title('Non-starting screen')\n",
    "plt.show()"
   ]
  },
  {
   "source": [
    "## Example of end state"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Force the manager's done flag so that get_state() is queried for a finished\n",
    "# episode without having to play one to the end.\n",
    "env_manager.done = True\n",
    "screen = env_manager.get_state()\n",
    "\n",
    "plt.figure()\n",
    "# Copy to the CPU before plotting: matplotlib cannot read a tensor that lives\n",
    "# on an accelerator device (the call is a no-op for CPU tensors).\n",
    "plt.imshow(screen.squeeze(0).permute(1, 2, 0).cpu(), interpolation='none')\n",
    "# Title fixed: this figure shows the end state, not the starting one.\n",
    "plt.title('End screen')\n",
    "plt.show()"
   ]
  },
  {
   "source": [
    "## Example of the plot function"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exercise the plot helper with 300 uniformly distributed random values.\n",
    "# NOTE(review): the second argument (100) is presumably a moving-average\n",
    "# window - confirm against entities.utils.plot.\n",
    "random_values = np.random.rand(300)\n",
    "plot(random_values, 100)"
   ]
  },
  {
   "source": [
    "# Main training loop"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define hyper parameters\n",
    "batch_size = 256        # experiences sampled per optimisation step\n",
    "gamma = 0.999           # discount applied to the next-state Q-values\n",
    "eps_start = 1           # eps_* are handed to EpsilonGreedyStrategy\n",
    "eps_end = 0.01\n",
    "eps_decay_rate = 0.001\n",
    "target_update = 10      # copy policy_net into target_net every N episodes\n",
    "memory_size = 100000    # capacity handed to ReplayMemory\n",
    "learning_rate = 0.001   # Adam learning rate\n",
    "num_episodes = 1000\n",
    "\n",
    "# Define main components\n",
    "device = create_torch_device()\n",
    "env_manager = EnvironmentManager(device)\n",
    "strategy = EpsilonGreedyStrategy(eps_start, eps_end, eps_decay_rate)\n",
    "agent = Agent(strategy, env_manager.num_actions_available(), device)\n",
    "memory = ReplayMemory(memory_size)\n",
    "\n",
    "# Define the neural networks: the target network starts as a frozen copy of\n",
    "# the policy network and is only refreshed every `target_update` episodes.\n",
    "policy_net = DQN(env_manager.get_screen_width(), env_manager.get_screen_height()).to(device)\n",
    "target_net = DQN(env_manager.get_screen_width(), env_manager.get_screen_height()).to(device)\n",
    "target_net.load_state_dict(policy_net.state_dict())\n",
    "target_net.eval()\n",
    "optimizer = optim.Adam(params=policy_net.parameters(), lr=learning_rate)\n",
    "\n",
    "# Store all the durations in this array\n",
    "episode_durations = []\n",
    "\n",
    "# try/finally guarantees that the environment is closed even when training is\n",
    "# interrupted (e.g. Kernel -> Interrupt) or an exception is raised mid-run.\n",
    "try:\n",
    "    # Start iterating the episodes\n",
    "    for i_episode in range(num_episodes):\n",
    "        # Reset the environment and get the first state\n",
    "        env_manager.reset()\n",
    "        state = env_manager.get_state()\n",
    "\n",
    "        # For each episode, play the game\n",
    "        for timestep in count():\n",
    "            # Make the agent select an action\n",
    "            action = agent.select_action(state, policy_net)\n",
    "            reward = env_manager.take_action(action)\n",
    "            next_state = env_manager.get_state()\n",
    "\n",
    "            # Store the experience in the memory\n",
    "            memory.push(Experience(state, action, next_state, reward))\n",
    "            state = next_state\n",
    "\n",
    "            # If we have enough experiences, start optimizing\n",
    "            if memory.can_sample_memory(batch_size):\n",
    "                experiences = memory.sample(batch_size)\n",
    "                states, actions, rewards, next_states = extract_tensors(experiences)\n",
    "\n",
    "                current_q_values = QValues.get_current(policy_net, states, actions)\n",
    "                next_q_values = QValues.get_next(target_net, next_states)\n",
    "                # Regression target: reward plus discounted next-state value.\n",
    "                target_q_values = next_q_values * gamma + rewards\n",
    "\n",
    "                # NOTE(review): unsqueeze(1) presumably aligns the targets with a\n",
    "                # (batch, 1) shaped current_q_values - confirm in QValues.\n",
    "                loss = F.mse_loss(current_q_values, target_q_values.unsqueeze(1))\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "            if env_manager.done:\n",
    "                # count() starts at 0, so the number of steps actually played in\n",
    "                # this episode is timestep + 1.\n",
    "                episode_durations.append(timestep + 1)\n",
    "                plot(episode_durations, 100)\n",
    "                break\n",
    "\n",
    "        # Periodically refresh the target network with the policy weights\n",
    "        if i_episode % target_update == 0:\n",
    "            target_net.load_state_dict(policy_net.state_dict())\n",
    "finally:\n",
    "    env_manager.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the models for inference\n",
    "# NOTE(review): POLICY_NET_FILE is imported at the top of the notebook but is\n",
    "# not used here - confirm whether it should replace the hard-coded file name.\n",
    "data_path = os.path.join(os.getcwd(), 'models')\n",
    "\n",
    "# torch.save does not create missing parent directories, so make sure the\n",
    "# output folder exists before writing (no error if it is already there).\n",
    "os.makedirs(data_path, exist_ok=True)\n",
    "\n",
    "policy_net_file = os.path.join(data_path, 'policy_network.pth')\n",
    "target_net_file = os.path.join(data_path, 'target_network.pth')\n",
    "\n",
    "torch.save(policy_net.state_dict(), policy_net_file)\n",
    "torch.save(target_net.state_dict(), target_net_file)"
   ]
  }
 ]
}